import pickle
import jax
import matplotlib.pyplot as plt
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from scipy.stats import gaussian_kde
import plotly.express as px
import pandas as pd
import pickle
# Short alias for the TFP (JAX substrate) distributions namespace, used to
# build the Normal densities in this script.
tfd = tfp.distributions
import plotly
# Configure Plotly for offline inline rendering.
# NOTE(review): this is a notebook-oriented call (the file appears to be an
# exported Jupyter notebook) — confirm it behaves as intended when the file
# is run as a plain script.
plotly.offline.init_notebook_mode()
# Shared evaluation grid for every marginal density in this script:
# 10,000 evenly spaced points covering [0, 6], both endpoints included.
x = jnp.linspace(start=0, stop=6, num=10000, endpoint=True)
# Load the fitted Ajax variational posterior and extract its mean vector and
# covariance matrix for theta.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load result files produced locally by this project.
with open('./results_data/linear_regression_Ajax', 'rb') as f:
    variational = pickle.load(f)
params = variational.get_params()
# Unpack the flattened pytree of the transformed distribution.
# NOTE(review): assumes its leaves are exactly (mean, scale factor) in that
# order — confirm against the Ajax distribution's pytree definition.
loc_m, scale_tril = jax.tree_util.tree_leaves(
    variational.transform_dist(params['theta'])
)
# Covariance from the scale factor S: Sigma = S @ S.T. From here on `scale`
# holds this covariance matrix; its diagonal gives the marginal variances.
scale = jnp.dot(scale_tril, scale_tril.T)
# A bare expression only displays in a notebook; print so a script shows it too.
# Previously recorded output (old JAX `DeviceArray` repr), kept for reference:
#   loc_m ~ [3.7587612, 2.1869001]
#   scale ~ [[ 0.13721034, -0.04376228],
#            [-0.04376228,  0.03969986]]
print(loc_m, scale)
# Start the list of PDF curves with the Ajax VI posterior marginals. Each
# theta_i is Normal with mean loc_m[i] and standard deviation
# sqrt(scale[i, i]), where `scale` is the VI covariance matrix.
all_pdf = []
for idx in range(2):
    marginal = tfd.Normal(loc=loc_m[idx], scale=jnp.sqrt(scale[idx, idx]))
    all_pdf.append(marginal.prob(x))
# Laplace approximation: read the stored mean vector and covariance matrix,
# then append the Normal marginal PDF of theta0 and theta1 evaluated on `x`.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted,
# locally produced result files.
with open('./results_data/linear_regression_laplace', 'rb') as fh:
    laplace = pickle.load(fh)
loc_m = laplace['mean']
std = jnp.sqrt(jnp.diag(laplace['cov']))
all_pdf.extend(tfd.Normal(loc=loc_m[k], scale=std[k]).prob(x) for k in range(2))
# MCMC (BlackJAX) posterior: smooth each parameter's samples with a Gaussian
# KDE and append the resulting density evaluated on `x`.
# NOTE(review): assumes position['theta'] is (num_samples, 2) with no extra
# chain axis — confirm against how the samples were saved.
with open('./results_data/MCMC_Blackjax', 'rb') as f:
    black_samples = pickle.load(f)
theta_draws = black_samples.position['theta']
for col in range(2):
    all_pdf.append(gaussian_kde(theta_draws[:, col])(x))
# Plot every marginal density on one figure, one line per (method, parameter).
# The order of `curve_labels` must match the order in which curves were
# appended to `all_pdf`: Ajax VI, then Laplace, then MCMC, each theta0/theta1.
curve_labels = [
    'Ajax VI theta0', 'Ajax VI theta1',
    'Laplace theta0', 'Laplace theta1',
    'MCMC theta0', 'MCMC theta1',
]
if len(all_pdf) != len(curve_labels):
    # Guard against the label list drifting out of sync with the curves.
    raise ValueError(
        f"expected {len(curve_labels)} PDF curves, got {len(all_pdf)}"
    )
n_grid = x.shape[0]
# Long format: each label repeated once per grid point, in curve order.
all_label = [label for label in curve_labels for _ in range(n_grid)]
all_pdf = jnp.array(all_pdf).reshape((-1))
x_repeated = jnp.tile(x, len(curve_labels))
to_df = {
    "theta": x_repeated,
    "PDF": all_pdf,
    "label": all_label,
}
df = pd.DataFrame(to_df)
fig = px.line(df, x="theta", y="PDF", color="label", title="linear regression")
fig.show()
fig.write_html("linear_reg_result_plotly.html")